Random walks in the latent space

This notebook reproduces experiment 4 in Arvanitidis et al. (2017). We train a convolutional VAE on frames of a video and visualize random walks in the latent space. These walks can be computed using either the Euclidean or the Riemannian metric. Since the Riemannian metric also takes the generator's variance into account, the random walk using the Riemannian metric will avoid regions of high variance.

Imports and setup

In [1]:
# Imports and setup of plotting library
%load_ext autoreload
%autoreload 2
%matplotlib inline
from copy import deepcopy

import numpy as np
import tensorflow as tf
from tensorflow.python.keras.datasets import mnist
from tensorflow.python.keras import Sequential, Model
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.layers import Dense, Input, Lambda, Conv2D
from tensorflow.python.keras.layers import Conv2DTranspose, Flatten, Reshape
from tensorflow.python.keras.constraints import NonNeg
from tensorflow.python.keras.initializers import RandomUniform
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import plotly.graph_objs as go
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot

from src.vae import VAE
from src.rbf import RBFLayer
from src.videoio import get_frames, load_from_pngs
from src.plot import plot_latent_curve_iterations, plot_magnification_factor

# Initialize plotly for offline rendering inside the notebook.
init_notebook_mode(connected=True)

# Shared figure layout reused by every plot in this notebook.
layout = go.Layout(
    showlegend=False,
    width=700,
    height=500,
    margin=go.Margin(t=20, b=40, l=60, r=60)
)
config = {'showLink': False}

# Seed both numpy and tensorflow so the whole run is repeatable.
seed = 0
np.random.seed(seed)
tf.set_random_seed(seed)
/Users/kilian/dev/tum/2018-mlic-kilian/venv/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters
In [2]:
# Debug toggle: a tiny epoch count for quick smoke tests,
# the full training budget otherwise.
debug = False
if debug:
    train_epochs = 10
else:
    train_epochs = 600

Construct a VAE

using the description in Appendix D in the paper. I have found subtracting the mean from the input data to cause slightly better reconstructions. Thus, the decoder's mean network has a tanh activation function in its last layer. Apart from this, the VAE below matches exactly the description in the paper.

In [3]:
# Implementation details from Appendix D
input_shape = (64, 64, 3)
latent_dim = 2
# Small weight decay applied to every layer below.
l2_reg = tf.keras.regularizers.l2(1e-5)

# Create the encoder models. The mean and variance heads share the
# convolutional feature extractor `enc_shared`.
enc_input = Input(input_shape)
enc_shared = Sequential([
    Conv2D(32, (3, 3), strides=(2, 2), activation='tanh', padding='same',
           input_shape=input_shape, kernel_regularizer=l2_reg),
    Conv2D(32, (3, 3), strides=(2, 2), activation='tanh', padding='same',
           kernel_regularizer=l2_reg),
    Flatten()
])
# Posterior mean head: linear output over the 2D latent space.
enc_mean = Sequential([
    enc_shared,
    Dense(1024, activation='tanh', kernel_regularizer=l2_reg),
    Dense(2, activation='linear', kernel_regularizer=l2_reg)
])
# Posterior variance head: softplus keeps the output non-negative.
enc_var = Sequential([
    enc_shared,
    Dense(1024, activation='tanh', kernel_regularizer=l2_reg),
    Dense(2, activation='softplus', kernel_regularizer=l2_reg)
])
enc_mean = Model(enc_input, enc_mean(enc_input))
enc_var = Model(enc_input, enc_var(enc_input))

# Create the decoder models. The mean network upsamples 16x16x3 -> 64x64x3;
# the final tanh matches the mean-subtracted (roughly zero-centered) inputs.
dec_input = Input((latent_dim,))
dec_mean = Sequential([
    Dense(1024, activation='tanh', kernel_regularizer=l2_reg),
    Dense(16 * 16 * 3, activation='tanh', kernel_regularizer=l2_reg),
    Reshape((16, 16, 3)),
    Conv2DTranspose(32, (3, 3), strides=(2, 2), activation='tanh',
                    padding='same', kernel_regularizer=l2_reg),
    Conv2DTranspose(32, (3, 3), strides=(2, 2), activation='tanh',
                    padding='same', kernel_regularizer=l2_reg),
    Conv2DTranspose(3, (3, 3), strides=(1, 1), activation='tanh',
                    padding='same', kernel_regularizer=l2_reg),
    Conv2D(3, (3, 3), strides=(1, 1), activation='tanh', padding='same',
           kernel_regularizer=l2_reg)
])

# Build the RBF network for the decoder's variance. `a` is the bandwidth
# scaling hyper-parameter used when computing the RBF bandwidths further
# down in the notebook.
num_centers = 64
a = 2.0
rbf = RBFLayer([32, 32, 3], num_centers)
# Non-negativity constraints keep the predicted variances valid.
var_constraint = NonNeg()
dec_var = Sequential([
    rbf,
    # NOTE(review): this layer instantiates NonNeg() directly rather than
    # reusing `var_constraint` like the layer below — harmless but
    # inconsistent.
    Conv2DTranspose(1, (3, 3), strides=(2, 2), activation='linear',
                    padding='same', kernel_constraint=NonNeg(),
                    bias_constraint=NonNeg(),
                    kernel_initializer=RandomUniform(minval=0, maxval=0.05),
                    kernel_regularizer=l2_reg),
    Conv2D(3, (3, 3), strides=(1, 1), activation='linear',
                    padding='same', kernel_constraint=var_constraint,
                    bias_constraint=var_constraint,
                    kernel_initializer=RandomUniform(minval=0, maxval=0.05),
                    kernel_regularizer=l2_reg),
])
dec_mean = Model(dec_input, dec_mean(dec_input))
dec_var = Model(dec_input, dec_var(dec_input))

# Assemble the full VAE; the mean decoder uses a fixed stddev of 1.
vae = VAE(enc_mean, enc_var, dec_mean, dec_var, dec_stddev=1.)
WARNING:tensorflow:Output "model_6" missing from loss dictionary. We assume this was done on purpose, and we will not be expecting any data to be passed to "model_6" during training.
WARNING:tensorflow:Output "model_6" missing from loss dictionary. We assume this was done on purpose, and we will not be expecting any data to be passed to "model_6" during training.
WARNING:tensorflow:Output "model_6" missing from loss dictionary. We assume this was done on purpose, and we will not be expecting any data to be passed to "model_6" during training.
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 64, 64, 3)    0                                            
__________________________________________________________________________________________________
sequential_2 (Sequential)       (None, 2)            8401826     input_1[0][0]                    
__________________________________________________________________________________________________
sequential_3 (Sequential)       (None, 2)            8401826     input_1[0][0]                    
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, 2)            0           sequential_2[1][0]               
                                                                 sequential_3[1][0]               
==================================================================================================
Total params: 16,793,508
Trainable params: 16,793,508
Non-trainable params: 0
__________________________________________________________________________________________________
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_2 (InputLayer)            (None, 2)            0                                            
__________________________________________________________________________________________________
sequential_4 (Sequential)       (None, 64, 64, 3)    801367      input_2[0][0]                    
__________________________________________________________________________________________________
sequential_5 (Sequential)       (None, 64, 64, 3)    196858      input_2[0][0]                    
__________________________________________________________________________________________________
lambda_2 (Lambda)               (None, 64, 64, 3)    0           sequential_4[1][0]               
                                                                 sequential_5[1][0]               
==================================================================================================
Total params: 998,225
Trainable params: 801,367
Non-trainable params: 196,858
__________________________________________________________________________________________________

Load the frames

and subtract the mean.

In [4]:
# Directory of extracted video frames.
# NOTE(review): this is an absolute local path — point it at your own
# frame directory before running the notebook.
FRAMES_DIR = '~/Desktop/trump-cut/'
x_train = load_from_pngs(FRAMES_DIR)

# Shuffle the training data, but save the permutation so the original
# frame order can be restored later (for the video reconstruction below)
permutation = np.random.permutation(len(x_train))
x_train = x_train[permutation]

# Subtract the mean frame; centering the data empirically improves the
# reconstructions (the decoder's mean network ends in a tanh)
x_mean = np.mean(x_train, axis=0)
x_train -= x_mean
plt.imshow(x_mean)
In [5]:
# Sanity check: display a single training frame (add the mean back so the
# image is in the original color range).
frame_to_show = x_train[10]
plt.imshow(frame_to_show + x_mean)
Out[5]:
<matplotlib.image.AxesImage at 0x11d24c6d8>

Train the VAE

In [6]:
# Fit the VAE on the centered frames, holding out 10% for validation.
history = vae.model.fit(x_train,
                        epochs=train_epochs,
                        batch_size=32,
                        validation_split=0.1,
                        verbose=0)

# Plot the train/validation loss curves over the epochs.
loss_traces = [
    go.Scatter(y=history.history['loss'], name='Train Loss'),
    go.Scatter(y=history.history['val_loss'], name='Validation Loss')
]
loss_layout = deepcopy(layout)
loss_layout['xaxis'] = {'title': 'Epoch'}
loss_layout['yaxis'] = {'title': 'NELBO'}
loss_layout['showlegend'] = True
iplot(go.Figure(data=loss_traces, layout=loss_layout), config=config)

Save and reload the model

Training takes about 8 hours, so we can simply reload the trained VAE here. This is why the execution count is suddenly larger in the cells below.

In [10]:
# Persist the trained networks (skipped in debug mode, where the short
# training run is not worth saving).
if not debug:
    for net, path in [(vae.encoder, 'models/video-encoder.h5'),
                      (vae.decoder, 'models/video-generator.h5')]:
        net.save(path, include_optimizer=False)
In [220]:
from src.vae import load_from

# Reload the fully trained VAE from disk (training takes ~8 hours).
vae = load_from('models/video-encoder.h5', 'models/video-generator.h5')
# Recover a handle to the RBF layer inside the decoder's variance network.
# NOTE(review): assumes decoder.layers[2] is the variance Sequential and
# that its first layer is the RBFLayer — confirm if the architecture
# in the construction cell changes.
rbf = vae.decoder.layers[2].layers[0]
WARNING:tensorflow:No training configuration found in save file: the model was *not* compiled. Compile it manually.
WARNING:tensorflow:No training configuration found in save file: the model was *not* compiled. Compile it manually.
WARNING:tensorflow:Output "model_182" missing from loss dictionary. We assume this was done on purpose, and we will not be expecting any data to be passed to "model_182" during training.
WARNING:tensorflow:Output "model_182" missing from loss dictionary. We assume this was done on purpose, and we will not be expecting any data to be passed to "model_182" during training.
WARNING:tensorflow:Output "model_182" missing from loss dictionary. We assume this was done on purpose, and we will not be expecting any data to be passed to "model_182" during training.

Visualize the latent representations

Both in the paper and here, we see chains of latent points indicating that frames in a sequence end up next to each other in the latent space.

In [221]:
# Encode every training frame; keep the posterior means for plotting.
sampled, encoded_mean, encoded_var = vae.encoder.predict(x_train)

# Scatter the latent means of all frames.
scatter_plot = go.Scatter(
    mode='markers',
    marker={'color': 'orange'},
    x=encoded_mean[:, 0],
    y=encoded_mean[:, 1]
)
data = [scatter_plot]
iplot(go.Figure(data=data, layout=layout), config=config)

Show a reconstructed video sequence

In [222]:
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
from moviepy.video.io.html_tools import ipython_display
# NOTE(review): scipy.misc.imresize was removed in SciPy 1.3 — this
# notebook requires an older SciPy; consider PIL's Image.resize instead.
from scipy.misc import imresize

# Display a sequence and its reconstructed version side by side
seq_length = 1000
# Invert the permutation with np.argsort so the frames come out in the
# original video order
seq_indices = np.argsort(permutation)[:seq_length]
sequence = x_train[seq_indices]
# The decoder outputs (sample, mean, variance); only the mean
# reconstruction is needed here.
_, reconstructed, _ = vae.decoder.predict(encoded_mean[seq_indices])

frames = []
for i in range(len(sequence)):
    # Original on the left, reconstruction on the right.
    frame = np.concatenate([sequence[i], reconstructed[i]], axis=1)
    # Undo the mean-centering applied before training.
    frame += np.concatenate([x_mean, x_mean], axis=1)
    # Map from [0, 1] floats to [0, 255] pixel values.
    frame = np.clip(frame, 0, 1) * 255.0
    # Scale from 64x64 per image to 256x256
    frame = imresize(frame, 4.0, 'nearest')
    frames.append(frame)
    
clip = ImageSequenceClip(sequence=frames, fps=30)
ipython_display(clip)
  0%|          | 0/1001 [00:00<?, ?it/s]
 10%|▉         | 98/1001 [00:00<00:00, 970.77it/s]
 22%|██▏       | 219/1001 [00:00<00:00, 1089.16it/s]
 34%|███▍      | 343/1001 [00:00<00:00, 1137.03it/s]
 47%|████▋     | 467/1001 [00:00<00:00, 1161.22it/s]
 59%|█████▉    | 594/1001 [00:00<00:00, 1181.59it/s]
 72%|███████▏  | 720/1001 [00:00<00:00, 1193.31it/s]
 84%|████████▍ | 844/1001 [00:00<00:00, 1199.53it/s]
 96%|█████████▋| 964/1001 [00:00<00:00, 1199.42it/s]
100%|██████████| 1001/1001 [00:00<00:00, 1197.29it/s]
Out[222]:

Train the generator's variance network

First, find the centers of the latent representations.

In [223]:
# Cluster the latent means; the k-means centers become the RBF centers.
kmeans_model = KMeans(n_clusters=num_centers, random_state=0).fit(encoded_mean)
centers = kmeans_model.cluster_centers_

# Overlay the centers (red) on the latent scatter (orange).
center_plot = go.Scatter(
    mode='markers',
    marker={'color': 'red'},
    x=centers[:, 0],
    y=centers[:, 1]
)
data = [scatter_plot, center_plot]
iplot(go.Figure(data=data, layout=layout), config=config)

Compute the bandwidths

In [224]:
# Cluster the latent representations
clustering = dict((c_i, []) for c_i in range(num_centers))
for z_i, c_i in zip(encoded_mean, kmeans_model.predict(encoded_mean)):
    clustering[c_i].append(z_i)
    
bandwidths = []
for c_i, cluster in clustering.items():
    if cluster:
        diffs = np.array(cluster) - centers[c_i]
        avg_dist = np.mean(np.linalg.norm(diffs, axis=1))
        bandwidth = 0.5 / (a * avg_dist)**2
    else:
        bandwidth = 0
    bandwidths.append(bandwidth)
bandwidths = np.array(bandwidths)

Train the generator's variance network

while keeping all other parameters of the VAE fixed, as described in the paper.

In [225]:
# Train the RBF variance network: recompile so only the variance network
# is trainable, then install the k-means centers and the computed
# bandwidths into the RBF layer before fitting.
vae.recompile_for_var_training()
kernel_weights = rbf.get_weights()[0]
rbf.set_weights([kernel_weights, centers, bandwidths])

# Fit the variance network for a single epoch.
history = vae.model.fit(x_train,
                        epochs=1,
                        batch_size=32,
                        validation_split=0.1,
                        verbose=0)
WARNING:tensorflow:Output "model_182" missing from loss dictionary. We assume this was done on purpose, and we will not be expecting any data to be passed to "model_182" during training.
WARNING:tensorflow:Output "model_182" missing from loss dictionary. We assume this was done on purpose, and we will not be expecting any data to be passed to "model_182" during training.
WARNING:tensorflow:Output "model_182" missing from loss dictionary. We assume this was done on purpose, and we will not be expecting any data to be passed to "model_182" during training.
In [226]:
# Extract the mean and std predictors
from src.util import wrap_model_in_float64
# The decoder outputs (sample, mean, variance); split off mean and variance.
_, mean, var = vae.decoder.output
# Standard deviation as a graph op on top of the variance output.
std = Lambda(tf.sqrt)(var)
# Flatten to vectors so Jacobians can be taken w.r.t. a flat output.
dec_mean = Model(vae.decoder.input, Flatten()(mean))
dec_std = Model(vae.decoder.input, Flatten()(std))
# Wrap in float64 — presumably for numerical precision in the metric /
# Jacobian computations below; see src.util.
dec_mean = wrap_model_in_float64(dec_mean)
dec_std = wrap_model_in_float64(dec_std)

Show a heatmap of magnification factors

In [227]:
# Square grid covering all latent points, padded by a margin of 10.
axis_length = max(abs(encoded_mean.min()), encoded_mean.max()) + 10
# Coarse grid in debug mode, fine 200x200 grid otherwise.
grid_resolution = 3 if debug else 200
heatmap_z1 = np.linspace(-axis_length, axis_length, grid_resolution)
heatmap_z2 = np.linspace(-axis_length, axis_length, grid_resolution)
heatmap = plot_magnification_factor(K.get_session(),
                                    heatmap_z1,
                                    heatmap_z2,
                                    dec_mean,
                                    dec_std,
                                    additional_data=[scatter_plot],
                                    layout=layout,
                                    log_scale=True,
                                    scale='hotcold')
Computing Magnification Factors:   0%|          | 0/40000 [00:00<?, ?it/s]
Computing Magnification Factors:   0%|          | 1/40000 [00:01<16:49:08,  1.51s/it]
Computing Magnification Factors:   0%|          | 10/40000 [00:01<1:48:06,  6.16it/s]
Computing Magnification Factors:   0%|          | 19/40000 [00:01<1:00:39, 10.99it/s]
Computing Magnification Factors:   0%|          | 28/40000 [00:01<43:41, 15.25it/s]  
Computing Magnification Factors:   0%|          | 37/40000 [00:01<34:59, 19.03it/s]
Computing Magnification Factors:   0%|          | 46/40000 [00:02<29:41, 22.43it/s]
Computing Magnification Factors:   0%|          | 55/40000 [00:02<26:08, 25.47it/s]
Computing Magnification Factors:   0%|          | 63/40000 [00:02<23:54, 27.84it/s]
Computing Magnification Factors:   0%|          | 72/40000 [00:02<21:55, 30.36it/s]
Computing Magnification Factors:   0%|          | 81/40000 [00:02<20:23, 32.64it/s]
Computing Magnification Factors:   0%|          | 89/40000 [00:02<19:25, 34.24it/s]
Computing Magnification Factors:   0%|          | 97/40000 [00:02<18:32, 35.86it/s]
Computing Magnification Factors:   0%|          | 106/40000 [00:02<17:38, 37.70it/s]
Computing Magnification Factors:   0%|          | 114/40000 [00:02<16:59, 39.11it/s]
Computing Magnification Factors:   0%|          | 122/40000 [00:03<16:25, 40.44it/s]
Computing Magnification Factors:   0%|          | 130/40000 [00:03<16:01, 41.46it/s]
Computing Magnification Factors: 100%|██████████| 40000/40000 [07:59<00:00, 83.49it/s]

Take a random walk

Define the Riemannian metric and the Euclidean metric (which is the identity matrix).

In [228]:
# Let's take a random walk
from tqdm import tqdm
from src.util import get_metric_op, get_numerical_jacobian
session = K.get_session()

def jac_fun(output_tensor, input_tensor):
    # Numerical Jacobian of output_tensor w.r.t. input_tensor, evaluated
    # in the current session.
    return get_numerical_jacobian(session, output_tensor, input_tensor)

# Build the riemannian function
# 2D latent point at which the metric is evaluated (float64 to match the
# wrapped decoder models).
point = tf.placeholder(tf.float64, [2])
metric_op = get_metric_op(point, dec_mean, dec_std, jac_fun=jac_fun)

def get_riemannian(position):
    # Evaluate the Riemannian metric at the given latent position.
    return session.run(metric_op, feed_dict={point: position})

def get_euclidean(position):
    # The Euclidean metric is simply the identity matrix.
    return np.eye(len(position))
In [242]:
def random_walk(metric_fun, num_steps=1000, step_size=1.):
    """Take a random walk in the latent space under the given metric.

    Each step draws isotropic Gaussian noise and rescales it by the
    inverse square root of the local metric, so directions with a large
    metric (e.g. high generator variance) are traversed with small steps.

    Args:
        metric_fun: callable mapping a latent position of shape (2,) to a
            symmetric positive-definite 2x2 metric matrix.
        num_steps: number of steps to take; the returned walk contains
            num_steps + 1 positions including the start at the origin.
        step_size: scalar multiplier applied to every step.

    Returns:
        Array of shape (num_steps + 1, 2) with all visited positions.
    """
    position = np.array([0., 0.])
    walk = [np.copy(position)]
    for _ in tqdm(range(num_steps), 'Taking Random Walk'):
        metric = metric_fun(position)
        # The metric is symmetric, so use eigh rather than eig: eigh
        # guarantees real eigenvalues/eigenvectors, whereas eig can return
        # spuriously complex pairs due to floating-point round-off.
        eigvals, eigvecs = np.linalg.eigh(metric)
        noise = np.random.randn(2)
        # Scale the noise by M^{-1/2} along each eigendirection.
        v = (eigvecs * (eigvals ** -0.5)).dot(noise)
        position += step_size * v
        walk.append(np.copy(position))

    return np.vstack(walk)

riemannian_walk = random_walk(get_riemannian)
euclidean_walk = random_walk(get_euclidean)

Visualize the random walk

In [243]:
# Overlay both walks on the latent scatter: Riemannian in green,
# Euclidean in pink.
riemannian_plot = go.Scatter(
    mode='lines',
    line={'width': 1, 'color': 'green'},
    x=riemannian_walk[:, 0],
    y=riemannian_walk[:, 1]
)
euclidean_plot = go.Scatter(
    mode='lines',
    line={'width': 1, 'color': '#ff005a'},
    x=euclidean_walk[:, 0],
    y=euclidean_walk[:, 1]
)
data = [scatter_plot, euclidean_plot, riemannian_plot]
iplot(go.Figure(data=data, layout=layout), config=config)
Taking Random Walk: 100%|██████████| 1000/1000 [00:12<00:00, 77.58it/s]
Taking Random Walk: 100%|██████████| 1000/1000 [00:00<00:00, 8039.99it/s]

Compare the reconstructions

at different steps of the random walks.

In [244]:
from src.plot import plot_images

# Decode the positions visited at selected steps of both walks and show
# the resulting frames side by side.
steps = [0, 200, 300, 800, 900, 1000]
images = {}
for step in steps:
    # Decode both walks' positions for this step in a single batch;
    # predict returns (sample, mean, variance) — keep the mean frames.
    positions = np.array([euclidean_walk[step], riemannian_walk[step]])
    _, decoded, _ = vae.decoder.predict(positions)
    for walk_name, frame in zip(('euclidean', 'riemannian'), decoded):
        # Undo the mean-centering and map to [0, 255] pixel values.
        frame = np.clip(frame + x_mean, 0, 1) * 255.0
        # Scale from 64x64 per image to 256x256
        frame = imresize(frame, 4.0, 'nearest')
        images['step %d %s' % (step, walk_name)] = frame
plot_images(images, nrows=6, ncols=2)